{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementation of Deep Convolutional GANs\n",
    "Reference: https://arxiv.org/pdf/1511.06434.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Run the comment below only when using Google Colab\n",
    "# !pip install torch torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torchvision.utils import save_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import datetime\n",
    "import os, sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from matplotlib.pyplot import imshow, imsave\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "MODEL_NAME = 'Conditional-DCGAN'\n",
    "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def to_onehot(x, num_classes=10):\n",
    "    assert isinstance(x, int) or isinstance(x, (torch.LongTensor, torch.cuda.LongTensor))\n",
    "    if isinstance(x, int):\n",
    "        c = torch.zeros(1, num_classes).long()\n",
    "        c[0][x] = 1\n",
    "    else:\n",
    "        x = x.cpu()\n",
    "        c = torch.LongTensor(x.size(0), num_classes)\n",
    "        c.zero_()\n",
    "        c.scatter_(1, x, 1) # dim, index, src value\n",
    "    return c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_sample_image(G, n_noise=100):\n",
    "    \"\"\"\n",
    "        save sample 100 images\n",
    "    \"\"\"\n",
    "    img = np.zeros([280, 280])\n",
    "    for j in range(10):\n",
    "        c = torch.zeros([10, 10]).to(DEVICE)\n",
    "        c[:, j] = 1\n",
    "        z = torch.randn(10, n_noise).to(DEVICE)\n",
    "        y_hat = G(z,c).view(10, 28, 28)\n",
    "        result = y_hat.cpu().data.numpy()\n",
    "        img[j*28:(j+1)*28] = np.concatenate([x for x in result], axis=-1)\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Discriminator(nn.Module):\n",
    "    \"\"\"\n",
    "        Convolutional Discriminator for MNIST\n",
    "    \"\"\"\n",
    "    def __init__(self, in_channel=1, input_size=784, condition_size=10, num_classes=1):\n",
    "        super(Discriminator, self).__init__()\n",
    "        self.transform = nn.Sequential(\n",
    "            nn.Linear(input_size+condition_size, 784),\n",
    "            nn.LeakyReLU(0.2),\n",
    "        )\n",
    "        self.conv = nn.Sequential(\n",
    "            # 28 -> 14\n",
    "            nn.Conv2d(in_channel, 512, 3, stride=2, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(512),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            # 14 -> 7\n",
    "            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(256),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            # 7 -> 4\n",
    "            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.AvgPool2d(4),\n",
    "        )\n",
    "        self.fc = nn.Sequential(\n",
    "            # reshape input, 128 -> 1\n",
    "            nn.Linear(128, 1),\n",
    "            nn.Sigmoid(),\n",
    "        )\n",
    "    \n",
    "    def forward(self, x, c=None):\n",
    "        # x: (N, 1, 28, 28), c: (N, 10)\n",
    "        x, c = x.view(x.size(0), -1), c.float() # may not need\n",
    "        v = torch.cat((x, c), 1) # v: (N, 794)\n",
    "        y_ = self.transform(v) # (N, 784)\n",
    "        y_ = y_.view(y_.shape[0], 1, 28, 28) # (N, 1, 28, 28)\n",
    "        y_ = self.conv(y_)\n",
    "        y_ = y_.view(y_.size(0), -1)\n",
    "        y_ = self.fc(y_)\n",
    "        return y_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Generator(nn.Module):\n",
    "    \"\"\"\n",
    "        Convolutional Generator for MNIST\n",
    "    \"\"\"\n",
    "    def __init__(self, input_size=100, condition_size=10):\n",
    "        super(Generator, self).__init__()\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(input_size+condition_size, 4*4*512),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        self.conv = nn.Sequential(\n",
    "            # input: 4 by 4, output: 7 by 7\n",
    "            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(256),\n",
    "            nn.ReLU(),\n",
    "            # input: 7 by 7, output: 14 by 14\n",
    "            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.ReLU(),\n",
    "            # input: 14 by 14, output: 28 by 28\n",
    "            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "        \n",
    "    def forward(self, x, c):\n",
    "        # x: (N, 100), c: (N, 10)\n",
    "        x, c = x.view(x.size(0), -1), c.float() # may not need\n",
    "        v = torch.cat((x, c), 1) # v: (N, 110)\n",
    "        y_ = self.fc(v)\n",
    "        y_ = y_.view(y_.size(0), 512, 4, 4)\n",
    "        y_ = self.conv(y_) # (N, 28, 28)\n",
    "        return y_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "D = Discriminator().to(DEVICE)\n",
    "G = Generator().to(DEVICE)\n",
    "# D.load_state_dict('D_dc.pkl')\n",
    "# G.load_state_dict('G_dc.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                transforms.Normalize((0.5,), (0.5,))]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "dataset = datasets.FashionMNIST(root='fashion_data/', train=True, transform=transform, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = dataset.targets == 9\n",
    "\n",
    "dataset.data = dataset.data[idx]\n",
    "dataset.targets = dataset.targets[idx]\n",
    "\n",
    "data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, drop_last=True, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "criterion = nn.BCELoss()\n",
    "D_opt = torch.optim.Adam(D.parameters(), lr=0.0005, betas=(0.5, 0.999))\n",
    "G_opt = torch.optim.Adam(G.parameters(), lr=0.0005, betas=(0.5, 0.999))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "max_epoch = 100 # need more than 20 epochs for training generator\n",
    "step = 0\n",
    "n_critic = 1 # for training more k steps about Discriminator\n",
    "n_noise = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "D_labels = torch.ones([batch_size, 1]).to(DEVICE) # Discriminator Label to real\n",
    "D_fakes = torch.zeros([batch_size, 1]).to(DEVICE) # Discriminator Label to fake"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0/100, Step: 0, D Loss: 1.3915491104125977, G Loss: 0.782874584197998\n",
      "Epoch: 2/100, Step: 500, D Loss: 0.10102915018796921, G Loss: 3.124659538269043\n",
      "Epoch: 5/100, Step: 1000, D Loss: 0.026514962315559387, G Loss: 4.6337385177612305\n",
      "Epoch: 8/100, Step: 1500, D Loss: 0.00831749476492405, G Loss: 5.531192779541016\n",
      "Epoch: 10/100, Step: 2000, D Loss: 0.005628030747175217, G Loss: 6.121631622314453\n",
      "Epoch: 13/100, Step: 2500, D Loss: 1.2646050453186035, G Loss: 0.5767863988876343\n",
      "Epoch: 16/100, Step: 3000, D Loss: 1.146709680557251, G Loss: 0.8445827960968018\n",
      "Epoch: 18/100, Step: 3500, D Loss: 1.3093414306640625, G Loss: 0.7794418334960938\n",
      "Epoch: 21/100, Step: 4000, D Loss: 1.3523428440093994, G Loss: 1.1322627067565918\n",
      "Epoch: 24/100, Step: 4500, D Loss: 1.2066881656646729, G Loss: 0.8252110481262207\n",
      "Epoch: 26/100, Step: 5000, D Loss: 1.3139939308166504, G Loss: 0.6843364834785461\n",
      "Epoch: 29/100, Step: 5500, D Loss: 1.286057472229004, G Loss: 0.8596994876861572\n",
      "Epoch: 32/100, Step: 6000, D Loss: 1.2956230640411377, G Loss: 0.681969404220581\n",
      "Epoch: 34/100, Step: 6500, D Loss: 1.4387595653533936, G Loss: 0.9044321775436401\n",
      "Epoch: 37/100, Step: 7000, D Loss: 1.4515787363052368, G Loss: 0.7265660166740417\n",
      "Epoch: 40/100, Step: 7500, D Loss: 1.4333388805389404, G Loss: 0.6892476081848145\n",
      "Epoch: 42/100, Step: 8000, D Loss: 1.3588316440582275, G Loss: 0.7377218008041382\n",
      "Epoch: 45/100, Step: 8500, D Loss: 1.342092514038086, G Loss: 0.6939644813537598\n",
      "Epoch: 48/100, Step: 9000, D Loss: 1.3554030656814575, G Loss: 0.6763062477111816\n",
      "Epoch: 50/100, Step: 9500, D Loss: 1.3818013668060303, G Loss: 0.7344485521316528\n",
      "Epoch: 53/100, Step: 10000, D Loss: 1.368570327758789, G Loss: 0.7152159214019775\n",
      "Epoch: 56/100, Step: 10500, D Loss: 1.4001199007034302, G Loss: 0.7099530696868896\n",
      "Epoch: 58/100, Step: 11000, D Loss: 1.4069771766662598, G Loss: 0.7039815187454224\n",
      "Epoch: 61/100, Step: 11500, D Loss: 1.3724676370620728, G Loss: 0.7245895862579346\n",
      "Epoch: 64/100, Step: 12000, D Loss: 1.3566913604736328, G Loss: 0.7094602584838867\n",
      "Epoch: 66/100, Step: 12500, D Loss: 1.382453441619873, G Loss: 0.7114216089248657\n",
      "Epoch: 69/100, Step: 13000, D Loss: 1.3584072589874268, G Loss: 0.6859960556030273\n",
      "Epoch: 72/100, Step: 13500, D Loss: 1.3444697856903076, G Loss: 0.6737773418426514\n",
      "Epoch: 74/100, Step: 14000, D Loss: 1.4096062183380127, G Loss: 0.7171881794929504\n",
      "Epoch: 77/100, Step: 14500, D Loss: 1.3902841806411743, G Loss: 0.6806784868240356\n",
      "Epoch: 80/100, Step: 15000, D Loss: 1.3703713417053223, G Loss: 0.681758463382721\n",
      "Epoch: 82/100, Step: 15500, D Loss: 1.4426193237304688, G Loss: 0.658524215221405\n",
      "Epoch: 85/100, Step: 16000, D Loss: 1.3985086679458618, G Loss: 0.6854572296142578\n",
      "Epoch: 88/100, Step: 16500, D Loss: 1.3863615989685059, G Loss: 0.7071089744567871\n",
      "Epoch: 90/100, Step: 17000, D Loss: 1.3698880672454834, G Loss: 0.6927786469459534\n",
      "Epoch: 93/100, Step: 17500, D Loss: 1.4185028076171875, G Loss: 0.6984732747077942\n",
      "Epoch: 96/100, Step: 18000, D Loss: 1.3933994770050049, G Loss: 0.7460653781890869\n",
      "Epoch: 98/100, Step: 18500, D Loss: 1.3553040027618408, G Loss: 0.7139111757278442\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(max_epoch):\n",
    "    for idx, (images, labels) in enumerate(data_loader):\n",
    "        # Training Discriminator\n",
    "        x = images.to(DEVICE)\n",
    "        y = labels.view(batch_size, 1)\n",
    "        y = to_onehot(y).to(DEVICE)\n",
    "        x_outputs = D(x, y)\n",
    "        D_x_loss = criterion(x_outputs, D_labels)\n",
    "\n",
    "        z = torch.randn(batch_size, n_noise).to(DEVICE)\n",
    "        z_outputs = D(G(z, y), y)\n",
    "        D_z_loss = criterion(z_outputs, D_fakes)\n",
    "        D_loss = D_x_loss + D_z_loss\n",
    "        \n",
    "        D.zero_grad()\n",
    "        D_loss.backward()\n",
    "        D_opt.step()\n",
    "\n",
    "        if step % n_critic == 0:\n",
    "            # Training Generator\n",
    "            z = torch.randn(batch_size, n_noise).to(DEVICE)\n",
    "            z_outputs = D(G(z, y), y)\n",
    "            G_loss = criterion(z_outputs, D_labels)\n",
    "\n",
    "            D.zero_grad()\n",
    "            G.zero_grad()\n",
    "            G_loss.backward()\n",
    "            G_opt.step()\n",
    "        \n",
    "        if step % 500 == 0:\n",
    "            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))\n",
    "                \n",
    "        if step % 1000 == 0:\n",
    "            G.eval()\n",
    "            img = get_sample_image(G, n_noise)\n",
    "            imsave('samples/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')\n",
    "            G.train()\n",
    "            \n",
    "        step += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# generation to image\n",
    "G.eval()\n",
    "imshow(get_sample_image(G, n_noise), cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def save_checkpoint(state, file_name='checkpoint.pth.tar'):\n",
    "    torch.save(state, file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Saving params.\n",
    "# torch.save(D.state_dict(), 'D_c.pkl')\n",
    "# torch.save(G.state_dict(), 'G_c.pkl')\n",
    "save_checkpoint({'epoch': epoch + 1, 'state_dict':D.state_dict(), 'optimizer' : D_opt.state_dict()}, 'D_dc.pth.tar')\n",
    "save_checkpoint({'epoch': epoch + 1, 'state_dict':G.state_dict(), 'optimizer' : G_opt.state_dict()}, 'G_dc.pth.tar')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
